package org.hepeng.workx.spring.session.web.http;

import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.hepeng.workx.extension.XLoader;
import org.hepeng.workx.util.proxy.Invocation;
import org.hepeng.workx.util.proxy.InvokeFilter;
import org.hepeng.workx.util.proxy.Invoker;
import org.hepeng.workx.util.proxy.ProxyFactory;
import org.hepeng.workx.util.proxy.TargetProxyMethodInvokerFilter;
import org.joor.Reflect;
import org.springframework.util.ClassUtils;

import javax.servlet.http.HttpServletRequest;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;

/**
 * Builds proxied Spring Session header-based session id components that also accept the
 * session id from a request parameter whose name equals the configured header name.
 *
 * <p>The proxy tries {@code HeaderHttpSessionIdResolver} first. If that fails, it falls back to
 * {@code HeaderHttpSessionStrategy}.
 *
 * @author he peng
 */
public class HttpSessionIdUtils {

    private static final ProxyFactory PROXY_FACTORY = XLoader.getXLoader(ProxyFactory.class).getX();

    private static final String HEADER_HTTP_SESSION_ID_RESOLVER_CLASS =
            "org.springframework.session.web.http.HeaderHttpSessionIdResolver";

    private static final String HEADER_HTTP_SESSION_STRATEGY_CLASS =
            "org.springframework.session.web.http.HeaderHttpSessionStrategy";

    /**
     * Intercepts {@code resolveSessionIds(HttpServletRequest)}.
     *
     * <p>If the target resolves no ids from the header, the request parameter named after the
     * target's {@code headerName} field is used instead.
     */
    private static final Invoker HEADER_HTTP_SESSION_ID_RESOLVER_INVOKER = new Invoker() {
        @Override
        public Object invoke(Invocation invocation) throws Throwable {
            // resolveSessionIds is declared to return List<String>
            @SuppressWarnings("unchecked")
            List<String> sessionIds = (List<String>) invocation.invoke();
            if (CollectionUtils.isEmpty(sessionIds)) {
                String val = readSessionIdParameter(invocation);
                sessionIds = val != null ? Collections.singletonList(val)
                        : Collections.<String>emptyList();
            }
            return sessionIds;
        }
    };

    /**
     * Intercepts the proxied {@code HeaderHttpSessionStrategy} methods.
     *
     * <p>Every call is delegated to the target. For {@code getRequestedSessionId}, a blank
     * header value is replaced by the request parameter named after the target's
     * {@code headerName} field.
     */
    private static final Invoker HEADER_HTTP_SESSION_STRATEGY_INVOKER = new Invoker() {
        @Override
        public Object invoke(Invocation invocation) throws Throwable {
            // Always run the real method (e.g. setHeaderName) instead of swallowing the call.
            Object result = invocation.invoke();
            if (!StringUtils.equals("getRequestedSessionId" , invocation.getMethod().getName())) {
                return result;
            }
            String sessionId = (String) result;
            if (StringUtils.isBlank(sessionId)) {
                sessionId = readSessionIdParameter(invocation);
            }
            return sessionId;
        }
    };

    /**
     * Reads the session id from the request parameter named after the proxied target's
     * {@code headerName} field.
     *
     * <p>Assumes the first invocation argument is the {@link HttpServletRequest}. This holds
     * for both intercepted methods.
     *
     * @return the parameter value, or {@code null} if the parameter is absent
     */
    private static String readSessionIdParameter(Invocation invocation) {
        HttpServletRequest request = (HttpServletRequest) invocation.getArgs()[0];
        String parameterName = Reflect.on(invocation.getNative()).get("headerName");
        return request.getParameter(parameterName);
    }

    /**
     * Creates a session id extractor that reads the id from the {@code headerName} header and
     * falls back to a request parameter with the same name.
     *
     * <p>A proxied {@code HeaderHttpSessionIdResolver} is tried first. If loading or proxying
     * it fails for any reason, a proxied {@code HeaderHttpSessionStrategy} is tried instead.
     *
     * @param headerName name of the header (and fallback request parameter) holding the session id
     * @return the proxy, or {@code null} if neither variant could be created; the failures are
     *         printed to stderr
     */
    public static Object createParameterAndHeaderSessionIdExtractor(String headerName) {
        try {
            return createHeaderHttpSessionIdResolverProxy(headerName);
        } catch (Exception resolverFailure) {
            try {
                return createHeaderHttpSessionStrategyProxy(headerName);
            } catch (Exception strategyFailure) {
                // Keep the first failure visible instead of silently dropping it.
                strategyFailure.addSuppressed(resolverFailure);
                strategyFailure.printStackTrace();
            }
        }
        return null;
    }

    /** Proxies {@code HeaderHttpSessionIdResolver(String)}, intercepting {@code resolveSessionIds}. */
    private static Object createHeaderHttpSessionIdResolverProxy(String headerName) throws Exception {
        // ClassUtils.forName throws ClassNotFoundException rather than returning null.
        Class<?> resolverClass = ClassUtils.forName(HEADER_HTTP_SESSION_ID_RESOLVER_CLASS, null);
        List<Class<?>> interfaces = Arrays.asList(resolverClass.getInterfaces());

        List<Class<?>> constructorArgTypes = new ArrayList<>();
        constructorArgTypes.add(String.class);
        List<Object> constructorArgs = new ArrayList<>();
        constructorArgs.add(headerName);

        List<Invoker> invokers = new ArrayList<>();
        invokers.add(HEADER_HTTP_SESSION_ID_RESOLVER_INVOKER);

        Set<Method> targetProxyMethods = new HashSet<>();
        targetProxyMethods.add(resolverClass.getMethod("resolveSessionIds" , HttpServletRequest.class));
        List<InvokeFilter> filters = new ArrayList<>();
        filters.add(new TargetProxyMethodInvokerFilter(targetProxyMethods));

        return PROXY_FACTORY.createProxy(resolverClass , interfaces , constructorArgTypes , constructorArgs , invokers , filters);
    }

    /**
     * Proxies {@code HeaderHttpSessionStrategy}, intercepting {@code getRequestedSessionId} and
     * {@code setHeaderName}.
     */
    private static Object createHeaderHttpSessionStrategyProxy(String headerName) throws Exception {
        Class<?> strategyClass = ClassUtils.forName(HEADER_HTTP_SESSION_STRATEGY_CLASS, null);
        List<Class<?>> interfaces = Arrays.asList(strategyClass.getInterfaces());

        List<Invoker> invokers = new ArrayList<>();
        invokers.add(HEADER_HTTP_SESSION_STRATEGY_INVOKER);

        Set<Method> targetProxyMethods = new HashSet<>();
        targetProxyMethods.add(strategyClass.getMethod("getRequestedSessionId" , HttpServletRequest.class));
        targetProxyMethods.add(strategyClass.getMethod("setHeaderName" , String.class));
        List<InvokeFilter> filters = new ArrayList<>();
        filters.add(new TargetProxyMethodInvokerFilter(targetProxyMethods));

        Object proxy = PROXY_FACTORY.createProxy(strategyClass, interfaces, null , null, invokers, filters);
        // This variant is created without constructor arguments, so the header name is applied via its setter.
        Reflect.on(proxy).call("setHeaderName" , headerName);
        return proxy;
    }
}
